# http://proceedings.mlr.press/v101/huang19a/huang19a.pdf
# https://www.researchgate.net/publication/220875351_Generative_Models_for_Labeling_Multi-object_Configurations_in_Images
# https://www.tensorflow.org/datasets/catalog/open_images_v4
# Auto-Encoding Progressive Generative Adversarial Networks For 3D Multi Object Scenes
TODO: decide which additional datasets to experiment with.
%config Completer.use_jedi = False
from ipywidgets import IntProgress
import matplotlib.pyplot as plt
from tensorflow.keras import layers, losses
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import logging
import tensorflow_datasets as tfds
import pandas as pd
from tqdm import tqdm_notebook as tqdm
from sklearn.mixture import GaussianMixture
import os
# Reproducibility: seed both numpy and TensorFlow from the single `seed`
# constant (previously the literal 1 was hard-coded into each call, so
# changing `seed` would silently have no effect).
seed = 1
np.random.seed(seed)
tf.random.set_seed(seed)
# Training configuration.
batch_size = 32
epochs = 10
# One of: 'bdd100k', 'flic', 'fashion_mnist', 'mnist', 'kitti', 'wider_face'.
dataset_name = 'kitti'
# Build train / test / validation splits for the configured dataset.
if dataset_name == 'bdd100k':
    # Local on-disk images; batched at load time.
    load_dir = tf.keras.preprocessing.image_dataset_from_directory
    train_ds = load_dir(directory='../data/bdd100k/images/10k/train1/', batch_size=batch_size)
    test_ds = load_dir(directory='../data/bdd100k/images/10k/test1/', batch_size=batch_size)
    validation_ds = load_dir(directory='../data/bdd100k/images/10k/val1/', batch_size=batch_size)
elif dataset_name in ['flic', 'fashion_mnist', 'mnist', 'kitti']:
    # TFDS datasets without a dedicated validation split: reuse test as validation.
    train_ds, test_ds = tfds.load(name=dataset_name, split=['train', 'test'],
                                  as_supervised=False, download=True)
    validation_ds = test_ds
elif dataset_name in ['wider_face']:
    # TFDS dataset that ships all three splits.
    train_ds, test_ds, validation_ds = tfds.load(name=dataset_name,
                                                 split=['train', 'test', 'validation'],
                                                 as_supervised=False, download=True)
else:
    raise ValueError(f'Unhandled dataset {dataset_name}')
# Collect per-example image dimensions to choose a common model input size.
if dataset_name == 'bdd100k':
    # Batched dataset of (images, labels): shapes are (batch, height, width, depth).
    records = [batch[0].get_shape().as_list() for batch in train_ds]
    dims_df = pd.DataFrame.from_records(data=records, columns=['batch', 'height', 'width', 'depth'])
else:
    # Unbatched TFDS dicts: image shapes are (height, width, depth).
    records = [example['image'].get_shape().as_list() for example in train_ds]
    dims_df = pd.DataFrame.from_records(data=records, columns=['height', 'width', 'depth'])
dims_df.describe()
| height | width | depth | |
|---|---|---|---|
| count | 6347.000000 | 6347.000000 | 6347.0 |
| mean | 374.481960 | 1240.112494 | 3.0 |
| std | 1.447946 | 5.220926 | 0.0 |
| min | 370.000000 | 1224.000000 | 3.0 |
| 25% | 375.000000 | 1242.000000 | 3.0 |
| 50% | 375.000000 | 1242.000000 | 3.0 |
| 75% | 375.000000 | 1242.000000 | 3.0 |
| max | 375.000000 | 1242.000000 | 3.0 |
# Pick a square model input size: the largest multiple of m that fits
# inside the smallest image observed in the dataset.
m = 20
height = int(min(dims_df['height']) / m) * m
width = int(min(dims_df['width']) / m) * m
depth = min(dims_df['depth'])
# Force a square input by using the smaller side for both dimensions.
height = width = min(height, width)
height, width, depth
(360, 360, 3)
# Peek at the feature keys of a few training examples.
for example in train_ds.take(3):
    print(example.keys())
dict_keys(['image', 'image/file_name', 'objects']) dict_keys(['image', 'image/file_name', 'objects']) dict_keys(['image', 'image/file_name', 'objects'])
# Scale pixels to [0, 1], resize to the model input size, and batch.
if dataset_name == 'bdd100k':
    # image_dataset_from_directory yields (image, label); drop the label.
    # BUG FIX: this branch previously only rescaled and never resized, so the
    # images would not match the (height, width) input the model is built for.
    train_ds = train_ds.map(lambda x0, x1: tf.image.resize(images=x0 / 255., size=[height, width]))
    test_ds = test_ds.map(lambda x0, x1: tf.image.resize(images=x0 / 255., size=[height, width]))
    validation_ds = validation_ds.map(lambda x0, x1: tf.image.resize(images=x0 / 255., size=[height, width]))
else:
    # TFDS dicts: extract the image, cast to float, rescale and resize.
    train_ds = train_ds.map(lambda x: tf.image.resize(images=tf.cast(x['image'], dtype=tf.float32) / 255.,
                                                      size=[height, width]))
    train_ds = train_ds.batch(batch_size, drop_remainder=True)
    ###
    test_ds = test_ds.map(lambda x: tf.image.resize(tf.cast(x['image'], dtype=tf.float32) / 255.,
                                                    size=[height, width]))
    test_ds = test_ds.batch(batch_size, drop_remainder=True)
    ###
    validation_ds = validation_ds.map(lambda x: tf.image.resize(tf.cast(x['image'], dtype=tf.float32) / 255.,
                                                                size=[height, width]))
    validation_ds = validation_ds.batch(batch_size, drop_remainder=True)
    ###
# Zip each dataset with itself: for an autoencoder the target IS the input.
train_ds_double_zipped = tf.data.Dataset.zip(datasets=(train_ds, train_ds))
test_ds_double_zipped = tf.data.Dataset.zip(datasets=(test_ds, test_ds))
validation_ds_double_zipped = tf.data.Dataset.zip(datasets=(validation_ds, validation_ds))
latent_dim = 128  # size of the autoencoder bottleneck vector z
class CAE(tf.keras.Model):
    """Convolutional autoencoder (deterministic — NOT variational).

    The previous docstring called this a "variational autoencoder", but there
    is no sampling or KL term: the encoder maps an image of shape
    (height, width, depth) to a plain `latent_dim`-dimensional vector and the
    decoder maps it back to an image.
    """

    def __init__(self, latent_dim):
        """Build the encoder and decoder networks.

        Args:
            latent_dim: size of the latent bottleneck vector z.
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.logger = logging.getLogger('CAE')  # kept for API parity; unused here
        # Encoder: two stride-2 convs downsample ~4x, then project to z.
        self.encoder = tf.keras.Sequential(
            name='encoder',
            layers=[
                tf.keras.layers.InputLayer(input_shape=(height, width, depth)),
                tf.keras.layers.Conv2D(
                    filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
                tf.keras.layers.Conv2D(
                    filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
                tf.keras.layers.Flatten(),
                # No activation: z is an unconstrained real vector.
                tf.keras.layers.Dense(latent_dim),
            ],
        )
        # Decoder: mirror of the encoder; two stride-2 transposed convs upsample 4x.
        self.decoder = tf.keras.Sequential(
            name='decoder',
            layers=[
                tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
                tf.keras.layers.Dense(units=(height // 4) * (width // 4) * 32,
                                      activation=tf.nn.relu),
                tf.keras.layers.Reshape(target_shape=(height // 4, width // 4, 32)),
                tf.keras.layers.Conv2DTranspose(
                    filters=64, kernel_size=3, strides=2, padding='same',
                    activation='relu'),
                tf.keras.layers.Conv2DTranspose(
                    filters=32, kernel_size=3, strides=2, padding='same',
                    activation='relu'),
                # No activation: raw pixel outputs; training loss is MSE
                # against images scaled to [0, 1].
                tf.keras.layers.Conv2DTranspose(
                    filters=depth, kernel_size=3, strides=1, padding='same'),
            ],
        )
        self.encoder.summary()
        self.decoder.summary()

    def call(self, x):
        """Reconstruct x by decoding its latent encoding."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
# Instantiate the autoencoder and compile with pixel-wise MSE reconstruction loss.
cae = CAE(latent_dim)
cae.compile(optimizer='adam', loss=losses.MeanSquaredError())
Model: "encoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 179, 179, 32) 896 _________________________________________________________________ conv2d_1 (Conv2D) (None, 89, 89, 64) 18496 _________________________________________________________________ flatten (Flatten) (None, 506944) 0 _________________________________________________________________ dense (Dense) (None, 128) 64888960 ================================================================= Total params: 64,908,352 Trainable params: 64,908,352 Non-trainable params: 0 _________________________________________________________________ Model: "decoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_1 (Dense) (None, 259200) 33436800 _________________________________________________________________ reshape (Reshape) (None, 90, 90, 32) 0 _________________________________________________________________ conv2d_transpose (Conv2DTran (None, 180, 180, 64) 18496 _________________________________________________________________ conv2d_transpose_1 (Conv2DTr (None, 360, 360, 32) 18464 _________________________________________________________________ conv2d_transpose_2 (Conv2DTr (None, 360, 360, 3) 867 ================================================================= Total params: 33,474,627 Trainable params: 33,474,627 Non-trainable params: 0 _________________________________________________________________
# Path encodes dataset, latent size and input shape so saved models never collide.
model_file_path = (
    f'./models/cae_dataset_{dataset_name}'
    f'_z_dim_{latent_dim}'
    f'_data_dim_{height}x{width}x{depth}'
)
print(f'model path = {model_file_path}')
model path = ./models/cae_dataset_kitti_z_dim_128_data_dim_360x360x3
# Load a previously trained model if one exists; otherwise train, checkpoint
# the best weights, and save the final model.
if os.path.exists(model_file_path):
    print('loading saved model')
    cae = tf.keras.models.load_model(filepath=model_file_path)
else:
    print('building model')
    # use checkpoints to save model fitting progress
    # https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint
    checkpoint_filepath = './checkpoint'
    # BUG FIX: val_loss must be MINIMIZED. mode was 'max', which makes the
    # checkpoint keep the weights of the WORST validation epoch.
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_loss',
        mode='min',
        save_best_only=True)
    # Model weights are saved at the end of every epoch, if it's the best seen
    # so far.
    cae.fit(x=train_ds_double_zipped, validation_data=test_ds_double_zipped, epochs=epochs,
            callbacks=[model_checkpoint_callback], use_multiprocessing=True)
    # The model weights (that are considered the best) are loaded into the model.
    cae.load_weights(checkpoint_filepath)
    print('saving model')
    cae.save(filepath=model_file_path)
loading saved model
# Materialize the validation dataset into one tensor by folding tf.concat
# over all batches. A zero tensor shaped like the first batch seeds the
# reduction and is sliced off afterwards.
for sample_batch in validation_ds.take(1):
    zero_state = tf.zeros(dtype=tf.float32, shape=sample_batch.shape)
validation_ds_tensor = validation_ds.reduce(
    initial_state=zero_state,
    reduce_func=lambda acc, batch: tf.concat(values=[acc, batch], axis=0))
# Drop the dummy zero rows contributed by the initial state.
validation_ds_tensor = validation_ds_tensor[batch_size:]
# Validation reconstruction loss; comparable across datasets because all
# pixels are scaled into [0, 1].
reconstructions = cae.predict(validation_ds)
cae_loss = cae.loss(y_true=validation_ds_tensor, y_pred=reconstructions).numpy()
print(f'CAE loss for dataset {dataset_name} = {np.round(cae_loss,4)}')
CAE loss for dataset kitti = 0.042899999767541885
# Plot decoded images: original (left) vs reconstruction (right).
for batch in validation_ds.take(2):
    z = cae.encoder(batch).numpy()
    decoded_imgs = cae.decoder(z).numpy()
    for i in range(batch.shape[0]):
        fig, (ax1, ax2) = plt.subplots(1, 2)
        ax1.imshow(batch[i])
        # The decoder's last layer has no activation, so outputs can fall
        # outside [0, 1]; clip explicitly instead of letting imshow warn.
        ax2.imshow(np.clip(decoded_imgs[i], 0., 1.))
        plt.show()
        # Close each figure after rendering — otherwise matplotlib keeps all
        # of them open and warns once more than 20 figures accumulate.
        plt.close(fig)
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/ipykernel_launcher.py:8: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Encode the test set into a single tensor of latent vectors z.
inf_or_unknown_cardinality = ((test_ds.cardinality() == tf.data.INFINITE_CARDINALITY)
                              or (test_ds.cardinality() == tf.data.UNKNOWN_CARDINALITY)).numpy()
# Fall back to a fixed batch budget when the dataset size is not known.
batches = test_ds.cardinality().numpy() if not inf_or_unknown_cardinality else 500
# PERF FIX: collect per-batch encodings in a list and concatenate ONCE at the
# end — the previous tf.concat-inside-the-loop re-copied the accumulated
# tensor every iteration (quadratic in the number of batches).
z_batches = []
with tqdm(total=batches) as pbar:
    for batch in test_ds.take(batches):
        z_batches.append(tf.convert_to_tensor(cae.encoder(batch).numpy()))
        pbar.update(1)
# Preserve the original behavior of z_tensor being None when no batch arrived.
z_tensor = tf.concat(z_batches, axis=0) if z_batches else None
z_tensor.shape
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/ipykernel_launcher.py:8: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0 Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
TensorShape([704, 128])
# 80/20 train/test split of the latent vectors for density modelling.
z_np = z_tensor.numpy()
split_at = int(0.8 * z_np.shape[0])
z_train = z_np[:split_at]
z_test = z_np[split_at:]
# Model-selection grid for fitting GMMs on the latent vectors.
random_state = 1
# Strong covariance regularization keeps 128-D fits numerically stable.
reg_covar = 0.1
logps = []  # one {'k', 'cov_type', 'logp'} record per fitted model
k_values = [1, 10, 20, 50, 70, 80, 100, 200]
# BUG FIX: 'cov' is not a valid sklearn covariance_type (valid values are
# 'full', 'tied', 'diag', 'spherical'); the sweep below actually uses
# 'diag' and 'full', so declare those here.
cov_types = ['diag', 'full']
# Fit a GMM for every (k, covariance type) pair and report the mean
# per-sample log-likelihood on the held-out latents.
for k in k_values:
    for cov_type in ['diag', 'full']:
        try:
            gm_fit = GaussianMixture(n_components=k, covariance_type=cov_type,
                                     random_state=random_state,
                                     reg_covar=reg_covar).fit(z_train)
            logp_gm = gm_fit.score(X=z_test)
            # Message typo fixed: 'Gaussin' -> 'Gaussian'.
            print(f'For Gaussian Mixture with k = {k} and cov type {cov_type}, logp = {logp_gm} ')
            logps.append({'k': k, 'cov_type': cov_type, 'logp': logp_gm})
            print('############## ')
        except Exception as e:
            # Deliberate best-effort sweep: a failing fit (e.g. singular
            # covariance) is reported and skipped, not allowed to abort the grid.
            # Message typo fixed: 'Catched expection' -> 'Caught exception'.
            print(f'Caught exception {e} ')
For Gaussin Mixture with k = 1 and cov type diag, logp = -262.341110386977 ############## For Gaussin Mixture with k = 1 and cov type full, logp = -11.438871900903225 ############## For Gaussin Mixture with k = 10 and cov type diag, logp = -170.6917955997237 ############## For Gaussin Mixture with k = 10 and cov type full, logp = -1.327926347027144 ############## For Gaussin Mixture with k = 20 and cov type diag, logp = -137.54189665032573 ############## For Gaussin Mixture with k = 20 and cov type full, logp = -0.37018736953522263 ############## For Gaussin Mixture with k = 50 and cov type diag, logp = -113.42024058446295 ############## For Gaussin Mixture with k = 50 and cov type full, logp = -6.77506646715561 ############## For Gaussin Mixture with k = 70 and cov type diag, logp = -103.88851239114166 ############## For Gaussin Mixture with k = 70 and cov type full, logp = -13.593873988321024 ############## For Gaussin Mixture with k = 80 and cov type diag, logp = -101.45960509411145 ############## For Gaussin Mixture with k = 80 and cov type full, logp = -16.753749716347343 ############## For Gaussin Mixture with k = 100 and cov type diag, logp = -100.49451496007462 ############## For Gaussin Mixture with k = 100 and cov type full, logp = -28.776018292355722 ############## For Gaussin Mixture with k = 200 and cov type diag, logp = -97.06446524302851 ############## For Gaussin Mixture with k = 200 and cov type full, logp = -65.05804713056555 ##############
# Tabulate the sweep results and rank by held-out log-likelihood, best first.
logps_df = pd.DataFrame.from_records(data=logps)
ranked = logps_df.sort_values(by='logp', ascending=False)
ranked.reset_index()
| index | k | cov_type | logp | |
|---|---|---|---|---|
| 0 | 5 | 20 | full | -0.370187 |
| 1 | 3 | 10 | full | -1.327926 |
| 2 | 7 | 50 | full | -6.775066 |
| 3 | 1 | 1 | full | -11.438872 |
| 4 | 9 | 70 | full | -13.593874 |
| 5 | 11 | 80 | full | -16.753750 |
| 6 | 13 | 100 | full | -28.776018 |
| 7 | 15 | 200 | full | -65.058047 |
| 8 | 14 | 200 | diag | -97.064465 |
| 9 | 12 | 100 | diag | -100.494515 |
| 10 | 10 | 80 | diag | -101.459605 |
| 11 | 8 | 70 | diag | -103.888512 |
| 12 | 6 | 50 | diag | -113.420241 |
| 13 | 4 | 20 | diag | -137.541897 |
| 14 | 2 | 10 | diag | -170.691796 |
| 15 | 0 | 1 | diag | -262.341110 |